from keras.datasets import mnist

import matplotlib.pyplot as plt


def load_mnist_data(normalize=False):
    """Load the MNIST train/test splits through ``keras.datasets.mnist``.

    Each image array gets a trailing axis of length 1 appended (a single
    channel dimension). Labels are returned unchanged.

    Args:
        normalize: when True, image arrays are cast to float32 and divided
            by 255 (presumably mapping 8-bit pixel values into [0, 1]).

    Returns:
        ``((x_train, y_train), (x_test, y_test))``.
    """
    train_split, test_split = mnist.load_data()

    def _prepare(split):
        images, labels = split
        # Append a singleton channel axis: shape (..., ) -> (..., 1).
        images = images.reshape(images.shape + (1,))
        if normalize:
            images = images.astype('float32') / 255.
        return images, labels

    return _prepare(train_split), _prepare(test_split)


def plot_image_rows(images_list, title_list):
    """Plot each row of images in its own figure, titled by the matching title.

    Args:
        images_list: sequence of rows, each row a sequence of images. An image
            is either 2-D, or indexable as ``img[:, :, 0]`` (the first channel
            of a channel-last array, e.g. from ``load_mnist_data``).
        title_list: titles paired with ``images_list`` via ``zip``; surplus
            entries on either side are ignored.
    """
    # Use the widest row so every figure has the same panel size and no row
    # can overflow its grid. An empty input draws nothing instead of raising.
    cols = max((len(row) for row in images_list), default=0)
    if cols == 0:
        return

    def plot_image_row(images, title):
        plt.figure(figsize=(cols, 3))
        plt.gcf().suptitle(title)
        for i, img in enumerate(images):
            # Each row lives in its own figure, so the grid is one row of
            # `cols` panels (the old `rows x cols` grid left all but the top
            # row blank and shrank every image).
            plt.subplot(1, cols, i + 1)
            pixels = img if getattr(img, 'ndim', None) == 2 else img[:, :, 0]
            plt.imshow(pixels, cmap='Greys_r')
            plt.axis('off')

    for images, title in zip(images_list, title_list):
        plot_image_row(images, title)

        
def plot_laplacian_variances(lvs_1, lvs_2, lvs_3, title):
    """Overlay three 50-bin histograms on the current matplotlib axes.

    Args:
        lvs_1: values for the 'Original images' histogram.
        lvs_2: values for the 'Images generated by plain VAE' histogram.
        lvs_3: values for the 'Images generated by DFC VAE' histogram.
        title: plot title.
    """
    series = (
        (lvs_1, 'Original images'),
        (lvs_2, 'Images generated by plain VAE'),
        (lvs_3, 'Images generated by DFC VAE'),
    )
    for values, label in series:
        # Translucent bars so the overlapping distributions stay visible.
        plt.hist(values, alpha=0.2, bins=50, label=label)
    plt.xlabel('Laplacian variance')
    plt.title(title)
    plt.legend()
